from inspect import signature

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from trainer.io import load_fsspec
from trainer.trainer_utils import get_optimizer, get_scheduler

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.models import setup_discriminator, setup_generator
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.generic_utils import plot_results


class GAN(BaseVocoder):
    def __init__(self, config: Coqpit, ap: AudioProcessor = None):
        """Wrap a generator and a discriminator network behind an interface compatible with the trainer.

        It also makes it easy to mix and match different generator and discriminator networks. To implement a
        new GAN model, you only need to define the generator and the discriminator networks; the rest is handled
        by the `GAN` class.

        Args:
            config (Coqpit): Model configuration.
            ap (AudioProcessor): 🐸TTS AudioProcessor instance. Defaults to None.

        Examples:
            Initializing the GAN model with HifiGAN generator and discriminator.
            >>> from TTS.vocoder.configs import HifiganConfig
            >>> config = HifiganConfig()
            >>> model = GAN(config)
        """
        super().__init__(config)
        self.config = config
        self.model_g = setup_generator(config)
        self.model_d = setup_discriminator(config)
        self.train_disc = False  # if False, train only the generator.

        # Values cached by the `optimizer_idx == 0` pass of `train_step()` and read back by the
        # `optimizer_idx == 1` pass. They are declared here so they always exist on the instance.
        self.y_hat_g = None  # the last (full-band) generator prediction
        self.y_hat_sub = None  # the raw multi-channel generator prediction, set only for PQMF generators
        self.y_sub_g = None  # the PQMF analysis of the target waveform, set only for PQMF generators
        self.ap = ap

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.forward(x)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        """Run the generator's inference pass.

        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: output of the GAN generator network.
        """
        return self.model_g.inference(x)

    def train_step(self, batch: dict, criterion: dict, optimizer_idx: int) -> tuple[dict, dict]:
        """Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
        network on the current pass.

        Args:
            batch (Dict): Batch of samples returned by the dataloader.
            criterion (Dict): Criterion used to compute the losses.
            optimizer_idx (int): ID of the optimizer in use on the current pass.

        Raises:
            ValueError: `optimizer_idx` is an unexpected value.

        Returns:
            Tuple[Dict, Dict]: model outputs and the computed loss values.
        """
        outputs = {}
        loss_dict = {}

        x = batch["input"]
        y = batch["waveform"]

        if optimizer_idx not in [0, 1]:
            raise ValueError(" [!] Unexpected `optimizer_idx`.")

        if optimizer_idx == 0:
            # DISCRIMINATOR optimization

            # generator pass
            y_hat = self.model_g(x)[:, :, : y.size(2)]

            # cache for generator loss
            # pylint: disable=W0201
            self.y_hat_g = y_hat
            self.y_hat_sub = None
            self.y_sub_g = None

            # PQMF formatting
            if y_hat.shape[1] > 1:
                self.y_hat_sub = y_hat
                y_hat = self.model_g.pqmf_synthesis(y_hat)
                self.y_hat_g = y_hat  # save for generator loss
                self.y_sub_g = self.model_g.pqmf_analysis(y)

            scores_fake, feats_fake, feats_real = None, None, None

            if self.train_disc:
                # use different samples for G and D trainings
                if self.config.diff_samples_for_G_and_D:
                    x_d = batch["input_disc"]
                    y_d = batch["waveform_disc"]
                    # use a different sample than generator
                    with torch.no_grad():
                        y_hat = self.model_g(x_d)

                    # PQMF formatting
                    if y_hat.shape[1] > 1:
                        y_hat = self.model_g.pqmf_synthesis(y_hat)
                else:
                    # use the same samples as generator
                    x_d = x.clone()
                    y_d = y.clone()
                    y_hat = self.y_hat_g

                # run D with or without cond. features
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(y_hat.detach().clone(), x_d)
                    D_out_real = self.model_d(y_d, x_d)
                else:
                    D_out_fake = self.model_d(y_hat.detach())
                    D_out_real = self.model_d(y_d)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    # self.model_d returns scores and features
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        scores_real, feats_real = None, None
                    else:
                        scores_real, feats_real = D_out_real
                else:
                    # model D returns only scores
                    scores_fake = D_out_fake
                    scores_real = D_out_real

                # compute losses
                loss_dict = criterion[optimizer_idx](scores_fake, scores_real)
                outputs = {"model_outputs": y_hat}

        if optimizer_idx == 1:
            # GENERATOR loss
            scores_fake, feats_fake, feats_real = None, None, None
            if self.train_disc:
                if len(signature(self.model_d.forward).parameters) == 2:
                    D_out_fake = self.model_d(self.y_hat_g, x)
                else:
                    D_out_fake = self.model_d(self.y_hat_g)
                D_out_real = None

                if self.config.use_feat_match_loss:
                    with torch.no_grad():
                        D_out_real = self.model_d(y)

                # format D outputs
                if isinstance(D_out_fake, tuple):
                    scores_fake, feats_fake = D_out_fake
                    if D_out_real is None:
                        feats_real = None
                    else:
                        _, feats_real = D_out_real
                else:
                    scores_fake = D_out_fake
                    feats_fake, feats_real = None, None

            # compute losses
            loss_dict = criterion[optimizer_idx](
                self.y_hat_g, y, scores_fake, feats_fake, feats_real, self.y_hat_sub, self.y_sub_g
            )
            outputs = {"model_outputs": self.y_hat_g}
        return outputs, loss_dict

    def _log(self, name: str, ap: AudioProcessor, batch: dict, outputs: dict) -> tuple[dict, dict]:
        """Build the figures and the audio sample logged by both the training and the evaluation.

        Args:
            name (str): Name of the run, `train` or `eval`. It prefixes the audio key.
            ap (AudioProcessor): Audio processor handed over to `plot_results()`.
            batch (Dict): Batch used in the last train/eval step. Only `waveform` is read.
            outputs (Dict): Model outputs from the last train/eval step, indexed by optimizer index. Entry `0` is
                used when the discriminator is being trained, entry `1` otherwise.

        Returns:
            Tuple[Dict, Dict]: log figures and audio samples.
        """
        step_idx = 0 if self.train_disc else 1
        prediction = outputs[step_idx]["model_outputs"]
        target = batch["waveform"]

        figures = plot_results(prediction, target, ap, name)

        # only the first item of the batch is logged as audio
        first_sample = prediction[0].squeeze(0).detach().cpu().numpy()
        return figures, {f"{name}/audio": first_sample}

    def train_log(
        self,
        batch: dict,
        outputs: dict,
        logger: "Logger",
        assets: dict,
        steps: int,  # pylint: disable=unused-argument
    ) -> tuple[dict, np.ndarray]:
        """Call `_log()` for training."""
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    @torch.inference_mode()
    def eval_step(self, batch: dict, criterion: nn.Module, optimizer_idx: int) -> tuple[dict, dict]:
        """Call `train_step()` with `no_grad()`"""
        self.train_disc = True  # Avoid a bug in the Training with the missing discriminator loss
        return self.train_step(batch, criterion, optimizer_idx)

    def eval_log(
        self,
        batch: dict,
        outputs: dict,
        logger: "Logger",
        assets: dict,
        steps: int,  # pylint: disable=unused-argument
    ) -> tuple[dict, np.ndarray]:
        """Call `_log()` for evaluation."""
        figures, audios = self._log("eval", self.ap, batch, outputs)
        logger.eval_figures(steps, figures)
        logger.eval_audios(steps, audios, self.ap.sample_rate)

    def load_checkpoint(
        self,
        config: Coqpit,
        checkpoint_path: str,
        eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
        cache: bool = False,
    ) -> None:
        """Load a GAN checkpoint and initialize model parameters.

        Args:
            config (Coqpit): Model config.
            checkpoint_path (str): Checkpoint file path.
            eval (bool, optional): If true, load the model for inference: the discriminator is dropped and the
                generator's weight norm is removed when it provides `remove_weight_norm()`. For legacy checkpoints
                the flag is forwarded to the generator's own `load_checkpoint()`. Defaults to False.
            cache (bool, optional): Forwarded to `load_fsspec()`. Defaults to False.
        """
        cpu = torch.device("cpu")
        state = load_fsspec(checkpoint_path, map_location=cpu, cache=cache)

        # band-aid for older than v0.0.15 GAN models: let the generator load the checkpoint itself
        if "model_disc" in state:
            self.model_g.load_checkpoint(config, checkpoint_path, eval=eval)
            return

        self.load_state_dict(state["model"])
        if not eval:
            return

        # inference only needs the generator
        self.model_d = None
        if hasattr(self.model_g, "remove_weight_norm"):
            self.model_g.remove_weight_norm()

    def on_train_step_start(self, trainer) -> None:
        """Enable the discriminator training based on `steps_to_start_discriminator`

        Args:
            trainer (Trainer): Trainer object.
        """
        self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator

    def get_optimizer(self) -> list:
        """Initiate and return the GAN optimizers based on the config parameters.

        It returns 2 optimizers in a list. The first one is for the discriminator and the second one is for the
        generator, matching the `optimizer_idx` values used by `train_step()`.

        Returns:
            List: optimizers.
        """
        cfg = self.config
        disc_optimizer = get_optimizer(cfg.optimizer, cfg.optimizer_params, cfg.lr_disc, self.model_d)
        gen_optimizer = get_optimizer(cfg.optimizer, cfg.optimizer_params, cfg.lr_gen, self.model_g)
        return [disc_optimizer, gen_optimizer]

    def get_lr(self) -> list:
        """Set the initial learning rates for each optimizer.

        Returns:
            List: learning rates for each optimizer.
        """
        return [self.config.lr_disc, self.config.lr_gen]

    def get_scheduler(self, optimizer) -> list:
        """Set the schedulers for each optimizer.

        Args:
            optimizer (List[`torch.optim.Optimizer`]): List of optimizers, the discriminator's first and the
                generator's second.

        Returns:
            List: Schedulers, one for each optimizer and in the same order.
        """
        cfg = self.config
        disc_optimizer = optimizer[0]
        disc_scheduler = get_scheduler(cfg.lr_scheduler_disc, cfg.lr_scheduler_disc_params, disc_optimizer)
        gen_optimizer = optimizer[1]
        gen_scheduler = get_scheduler(cfg.lr_scheduler_gen, cfg.lr_scheduler_gen_params, gen_optimizer)
        return [disc_scheduler, gen_scheduler]

    @staticmethod
    def format_batch(batch: list) -> dict:
        """Format the batch for training.

        Args:
            batch (List): Batch out of the dataloader.

        Returns:
            Dict: formatted model inputs.
        """
        if isinstance(batch[0], list):
            x_G, y_G = batch[0]
            x_D, y_D = batch[1]
            return {"input": x_G, "waveform": y_G, "input_disc": x_D, "waveform_disc": y_D}
        x, y = batch
        return {"input": x, "waveform": y}

    def get_data_loader(  # pylint: disable=no-self-use, unused-argument
        self,
        config: Coqpit,
        assets: dict,
        is_eval: bool,
        samples: list,
        verbose: bool,
        num_gpus: int,
        rank: int | None = None,  # pylint: disable=unused-argument
    ):
        """Initiate and return the GAN dataloader.

        The dataset is built with the model's own audio processor (`self.ap`).

        Args:
            config (Coqpit): Model config.
            assets (Dict): Unused.
            is_eval (bool): Set the dataloader for evaluation if true. Evaluation uses a batch size of 1, full
                items instead of segments, and `config.num_eval_loader_workers` workers.
            samples (List): Data samples.
            verbose (bool): Unused.
            num_gpus (int): Number of GPUs in use. A `DistributedSampler` is used when it is larger than 1.
            rank (int): Rank of the current GPU. Unused. Defaults to None.

        Returns:
            DataLoader: Torch dataloader.
        """
        dataset = GANDataset(
            ap=self.ap,
            items=samples,
            seq_len=config.seq_len,
            hop_len=self.ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_eval else config.batch_size,
            # `shuffle` and `sampler` are mutually exclusive in `DataLoader`: let the loader shuffle whenever no
            # sampler does it, which covers single-GPU runs as well as CPU runs.
            shuffle=sampler is None,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
            pin_memory=False,
        )
        return loader

    def get_criterion(self):
        """Return the criterions, the discriminator loss first and the generator loss second."""
        disc_criterion = DiscriminatorLoss(self.config)
        gen_criterion = GeneratorLoss(self.config)
        return [disc_criterion, gen_criterion]

    @staticmethod
    def init_from_config(config: Coqpit) -> "GAN":
        """Build a `GAN` model, together with its audio processor, from `config`.

        Args:
            config (Coqpit): Model config, also used to initialize the `AudioProcessor`.

        Returns:
            GAN: The initialized model.
        """
        return GAN(config, ap=AudioProcessor.init_from_config(config))
